3 sum with multiplicity [Three Pointer, Counting with Cases]

Time: O(N^2); Space: O(N); medium

Given an integer array A, and an integer target, return the number of tuples i, j, k such that i < j < k and A[i] + A[j] + A[k] == target.

As the answer can be very large, return it modulo 10^9 + 7.

Example 1:

Input: A = [1,1,2,2,3,3,4,4,5,5], target = 8

Output: 20

Explanation:

  • Enumerating by the values (A[i], A[j], A[k]): (1, 2, 5) occurs 8 times; (1, 3, 4) occurs 8 times; (2, 2, 4) occurs 2 times; (2, 3, 3) occurs 2 times.

Example 2:

Input: A = [1,1,2,2,2,2], target = 5

Output: 12

Explanation:

  • A[i] = 1, A[j] = A[k] = 2 occurs 12 times:

  • We choose one 1 from [1,1] in 2 ways, and two 2s from [2,2,2,2] in 6 ways.

Notes:

  • 3 <= A.length <= 3000

  • 0 <= A[i] <= 100

  • 0 <= target <= 300

1. Approach Notes

The approaches described below assume some familiarity with the Two Pointer technique that can be used to solve the LeetCode problem “Two Sum”. In the problem, we have a sorted array A of unique elements, and want to know how many i < j with A[i] + A[j] == target. The idea that does it in linear time, is that for each i in increasing order, the j’s that satisfy the equation A[i] + A[j] == target are decreasing.

def solve(A, target): # Assume A already sorted i, j = 0, len(A) - 1 ans = 0 while i < j: if A[i] + A[j] < target: i += 1 elif A[i] + A[j] > target: j -= 1 else: ans += 1 i += 1 j -= 1 return ans This is not a complete explanation. For more on this problem, please review the LeetCode problem “Two Sum”.

[1]:
class Solution1(object):
    """
    Time: O(N^2), where N is the length of A.
    Space: O(1).
    """
    def threeSumMulti(self, A, target):
        """
        :type A: List[int]
        :type target: int
        :rtype: int
        """
        MOD = 10**9 + 7
        count = [0] * 101
        for x in A:
            count[x] += 1

        ans = 0

        # All different
        for x in range(101):
            for y in range(x+1, 101):
                z = target - x - y
                if y < z <= 100:
                    ans += count[x] * count[y] * count[z]
                    ans %= MOD

        # x == y
        for x in range(101):
            z = target - 2*x
            if x < z <= 100:
                ans += count[x] * (count[x] - 1) // 2 * count[z]
                ans %= MOD

        # y == z
        for x in range(101):
            if (target - x) % 2 == 0:
                y = (target - x) // 2
                if x < y <= 100:
                    ans += count[x] * count[y] * (count[y] - 1) // 2
                    ans %= MOD

        # x == y == z
        if target % 3 == 0:
            x = target // 3
            if 0 <= x <= 100:
                ans += count[x] * (count[x] - 1) * (count[x] - 2) // 6
                ans %= MOD

        return ans
[2]:
s = Solution1()
A = [1,1,2,2,3,3,4,4,5,5]
target = 8
assert s.threeSumMulti(A, target) == 20
A = [1,1,2,2,2,2]
target = 5
assert s.threeSumMulti(A, target) == 12

2. Three Pointer

Intuition and Algorithm

Sort the array. For each i, set T = target - A[i], the remaining target. We can try using a two-pointer technique to find A[j] + A[k] == T. This approach is the natural continuation of trying to make the two-pointer technique we know from previous problems, work on this problem. Because some elements are duplicated, we have to be careful. In a typical case, the target is say, 8, and we have a remaining array (A[i+1:]) of [2,2,2,2,3,3,4,4,4,5,5,5,6,6].

We can analyze this situation with cases. Whenever A[j] + A[k] == T, we should count the multiplicity of A[j] and A[k]. In this example, if A[j] == 2 and A[k] == 6, the multiplicities are 4 and 2, and the total number of pairs is 4 * 2 = 8. We then move to the remaining window A[j:k+1] of [3,3,4,4,4,5,5,5].

As a special case, if A[j] == A[k], then our manner of counting would be incorrect. If for example the remaining window is [4,4,4], there are only 3 such pairs. In general, when A[j] == A[k], we have binom(M)(2) = M*(M-1)/2 pairs (j,k) (with j < k) that satisfy A[j] + A[k] == T, where M is the multiplicity of A[j] (in this case M=3).

For more details, please see the inline comments.

[3]:
class Solution2(object):
    """
    Three Pointer
    Time: O(N^2), where N is the length of A.
    Space: O(1).
    """
    def threeSumMulti(self, A, target):
        """
        :type A: List[int]
        :type target: int
        :rtype: int
        """
        MOD = 10**9 + 7
        ans = 0
        A.sort()

        for i, x in enumerate(A):
            # We'll try to find the number of i < j < k
            # with A[j] + A[k] == T, where T = target - A[i].

            # The below is a "two sum with multiplicity".
            T = target - A[i]
            j, k = i+1, len(A) - 1

            while j < k:
                # These steps proceed as in a typical two-sum.
                if A[j] + A[k] < T:
                    j += 1
                elif A[j] + A[k] > T:
                    k -= 1
                # These steps differ:
                elif A[j] != A[k]: # We have A[j] + A[k] == T.
                    # Let's count "left": the number of A[j] == A[j+1] == A[j+2] == ...
                    # And similarly for "right".
                    left = right = 1
                    while j + 1 < k and A[j] == A[j+1]:
                        left += 1
                        j += 1
                    while k - 1 > j and A[k] == A[k-1]:
                        right += 1
                        k -= 1

                    # We contributed left * right many pairs.
                    ans += left * right
                    ans %= MOD
                    j += 1
                    k -= 1

                else:
                    # M = k - j + 1
                    # We contributed M * (M-1) / 2 pairs.
                    ans += (k-j+1) * (k-j) / 2
                    ans %= MOD
                    break

        return ans
[4]:
s = Solution2()
A = [1,1,2,2,3,3,4,4,5,5]
target = 8
assert s.threeSumMulti(A, target) == 20
A = [1,1,2,2,2,2]
target = 5
assert s.threeSumMulti(A, target) == 12

3. Counting with Cases

Intuition and Algorithm Let count[x] be the number of times that x occurs in A. For every x+y+z == target, we can try to count the correct contribution to the answer. There are a few cases: * If x, y, and z are all different, then the contribution is count[x] * count[y] * count[z]. * If x == y != z, the contribution is binom((count[x])(2)) * count[z]] * If x != y == z, the contribution is count[x] * binom((count[y])(2)) * If x == y == z, the contribution is binom((count[x])(3)) Here, binom(n)(k) denotes the binomial coefficient n!//(n-k)!k! Each case is commented in the implementations below.

[7]:
class Solution3(object):
    """
    Counting with Cases
    Time: O(N + W^2), where N is the length of A, and W is the maximum possible value of A[i].
         (Note that this solution can be adapted to be O(N^2) even in the case that W is very large.)
    Space: O(W).
    """
    def threeSumMulti(self, A, target):
        """
        :type A: List[int]
        :type target: int
        :rtype: int
        """
        MOD = 10**9 + 7
        count = [0] * 101
        for x in A:
            count[x] += 1

        ans = 0

        # All different
        for x in range(101):
            for y in range(x+1, 101):
                z = target - x - y
                if y < z <= 100:
                    ans += count[x] * count[y] * count[z]
                    ans %= MOD

        # x == y
        for x in range(101):
            z = target - 2*x
            if x < z <= 100:
                ans += count[x] * (count[x] - 1) // 2 * count[z]
                ans %= MOD

        # y == z
        for x in range(101):
            if (target - x) % 2 == 0:
                y = (target - x) // 2
                if x < y <= 100:
                    ans += count[x] * count[y] * (count[y] - 1) // 2
                    ans %= MOD

        # x == y == z
        if target % 3 == 0:
            x = target // 3
            if 0 <= x <= 100:
                ans += count[x] * (count[x] - 1) * (count[x] - 2) // 6
                ans %= MOD

        return ans
[8]:
s = Solution3()
A = [1,1,2,2,3,3,4,4,5,5]
target = 8
assert s.threeSumMulti(A, target) == 20
A = [1,1,2,2,2,2]
target = 5
assert s.threeSumMulti(A, target) == 12
[9]:
import collections
import itertools

class Solution4(object):
    """
    Time:  O(N^2), N is the number of disctinct A[i]
    Space: O(N)
    """
    def threeSumMulti(self, A, target):
        """
        :type A: List[int]
        :type target: int
        :rtype: int
        """
        count = collections.Counter(A)
        result = 0
        for i, j in itertools.combinations_with_replacement(count, 2):
            k = target - i - j
            if i == j == k:
                result += count[i] * (count[i]-1) * (count[i]-2) // 6
            elif i == j != k:
                result += count[i] * (count[i]-1) // 2 * count[k]
            elif max(i, j) < k:
                result += count[i] * count[j] * count[k]
        return result % (10**9 + 7)
[10]:
s = Solution4()
A = [1,1,2,2,3,3,4,4,5,5]
target = 8
assert s.threeSumMulti(A, target) == 20
A = [1,1,2,2,2,2]
target = 5
assert s.threeSumMulti(A, target) == 12